Predicting Stock Market Volatility Using Graph Attention Networks
Capturing Inter-Stock Dynamics for Enhanced Accuracy
Authors
OPTIVER STREAM, GROUP 21:
Shreya Prakash (520496062)
Chenuka Garunsinghe (530080640)
Binh Minh Tran (530414672)
Enoch Wong (530531430)
Ruohai Tao (540222281)
Zoha Kausar (530526838)
1. Executive Summary
Realized volatility (RV), measuring stock price fluctuations, is vital for financial risk management but difficult to predict due to inter-stock dependencies. This study demonstrates that a Graph Attention Network (GAT) can model these relationships dynamically, achieving an RMSE of 1.236 x 10-3, surpassing traditional models like HAR-RV. The GAT enhances accuracy and interpretability by revealing sector-based influences. Deployed in the VoltaTrade Shiny app, this solution enables investors and stock traders to access real-time, interactive RV predictions, helping them assess risks, optimize portfolios, and refine trading strategies. Combining finance, artificial intelligence (deep learning), and data science, this interdisciplinary work delivers a scalable forecasting tool.
Innovations: This research introduces three key advancements:
Utilization of a Graph Attention Network to model stock relationships dynamically, unlike static traditional methods.
Incorporation of inter-stock dynamics to enhance forecasting accuracy through attention-based weighting.
Deployment in a Shiny application, enabling user-friendly, real-time volatility analysis for broader accessibility.
2. Introduction
The rise of stock trading in the 2010s, amplified in the 2020s by zero-commission platforms and social media movements, has heightened the need for reliable forecasting tools (Governors of the Federal Reserve System, 2022). Realized volatility (RV), which quantifies stock price fluctuations over a specific period, is a cornerstone for risk management, derivative pricing, and portfolio optimization in financial markets (Andersen et al., 2003); (Hull, 2015).
However, most traditional models used to forecast volatility - such as GARCH or HAR-RV - assume linearity and treat stocks as independent entities, ignoring the non-linear dynamics and cross-sectional dependencies that govern real-world markets (Corsi, 2009); (Cont, 2001). This creates a performance ceiling in modern settings, where market behaviour is complex, high-frequency, and increasingly interlinked. This formed our research question:
Research Question Can graph neural networks identify inter-stock relationships using market metrics to improve RV forecasting accuracy over traditional models?
To address this, we developed a graph-based deep learning framework that models the stock market as an evolving graph, where nodes represent stocks, and edges reflect historical similarities. Using a Graph Attention Network (GAT), we dynamically weigh inter-stock influences, enhancing prediction accuracy and interpretability (Velickovic et al., 2018); (Zhang et al., 2022).
We also deployed this GAT-based pipeline into a real-time Shiny web-application which empowers individual investors and quantitative-traders by allowing them to interact with volatility predictions, visually explore inter-stock relationships, and filter based on configurable metrics - making a traditionally opaque process more transparent, interpretable, and user-centric.
Inspired by the Kaggle Optiver Realized Volatility Prediction Challenge(Optiver, 2021), we focus on building a practical, and scalable end-to-end volatility prediction system that integrates cross-disciplinary theory of -
Market behaviour and volatility mechanisms from Finance.
Modelling of complex non-linear dependencies from Deep Learning (AI).
Full modelling life-cycle, from data cleaning to real-time deployment from Data Science.
3. Methodology
Image 1: Overview of Data Analysis Process
3.1 Data Pre-processing
3.1.1 Optiver LOB dataset
Show Code
# list of all importsfrom glob import globimport pandas as pd, osimport numpy as npimport torchimport matplotlib.pyplot as pltfrom sklearn.decomposition import PCAfrom sklearn.preprocessing import StandardScalerfrom sklearn.linear_model import LinearRegressionfrom sklearn.ensemble import RandomForestRegressorfrom sklearn.ensemble import GradientBoostingRegressorfrom itertools import productimport warningsfrom pandas.errors import PerformanceWarning# Suppress fragmentation warnings from pandaswarnings.filterwarnings("ignore", category=PerformanceWarning)# list all book files and the target tablereal_volatility = pd.read_csv('real_volatility.csv')book_paths =sorted(glob("individual_book_train/stock_*.csv"))df_files = (pd.DataFrame({"path": book_paths}) .assign(stock_id=lambda d: d["path"].astype(str) .str.extract(r'(\d+)').astype(int)))device = torch.device('mps'if torch.mps.is_available() else'cpu') # can change mps to cuda for non metal devices
We use the Optiver limit-order-book (LOB) training set: 120 stocks × 3830 ten-minute buckets, each bucket containing 600 one-second snapshots of the top-of-book. The public train.csv file provides realized-volatility targets; all other fields come from the per-stock book_*.csv tables.
3.2.2 De-normalization & cleaning
For every stock ID and time ID we forward-fill the 600 snapshots, compute the smallest non-zero jump \(\delta P\), and rescale the whole bucket by \(\frac{0.01}{\delta P}\), thereby recovering genuine dollar prices.
Missing data is rare; when a stock misses > 0.05 % of seconds on any trading day (≈ 44/88 200) we drop it. For the remaining series we Winsorise at the 0.1 % / 99.9 % quantiles to suppress single-tick glitches.
Because the time-id values are shuffled, we restore chronology with a 1-D spectral embedding: treat each bucket as a point in ℝ¹²⁰, embed with the leading Laplacian eigen-vector, and sort. Applied to S&P-100 closes this technique reproduces calendar order perfectly, so we adopt it here.
Show Code
from sklearn.manifold import SpectralEmbeddingfrom sklearn.preprocessing import minmax_scaleimport yfinance as yfdef spectral_order(df, k=30, seed=42):""" Return index sorted by the leading spectral coordinate. df : (n_buckets × n_stocks) price matrix with *no* NaNs. """ df_clean = df.fillna(df.mean()) X = minmax_scale(df_clean.values) # normalise emb_2d = SpectralEmbedding(random_state=seed).fit_transform(X) coord = emb_2d[:, 0]return df.index[coord.argsort()]THRESHOLD =0.0005keep = df_prices.isna().mean().le(0.0005)df_prices = df_prices.loc[:, keep]# winsoriseq_lo, q_hi = df_prices.quantile(0.001), df_prices.quantile(0.999)df_prices_denorm_clean = df_prices.clip(lower=q_lo, upper=q_hi, axis=1).ffill().bfill()time_id_ordered = spectral_order(df_prices_denorm_clean)df_prices_ordered = df_prices_denorm_clean.reindex(time_id_ordered)
Show Code
import contextlibimport sysimport ioimport warningsfrom sklearn.manifold import SpectralEmbeddingfrom sklearn.preprocessing import minmax_scaleimport yfinance as yfimport pandas as pdimport numpy as npimport matplotlib.pyplot as pltimport matplotlib.dates as mdatesf = io.StringIO()with warnings.catch_warnings(), contextlib.redirect_stdout(f), contextlib.redirect_stderr(f): warnings.simplefilter("ignore")# Download S&P-100 data sp100 = pd.read_html("https://en.wikipedia.org/wiki/S%26P_100")[2].Symbol sp100 = sp100.str.replace('.', '-', regex=False) df_real = (yf.download(sp100.to_list(), start="2020-01-01", end="2021-06-01", interval="1d")['Close'] .dropna(axis=1, thresh=0.5*len(sp100)) .dropna())# Embed both matrices in 2-D for eyeballingembed = SpectralEmbedding(n_components=2, random_state=42)Z_denorm = embed.fit_transform(minmax_scale(df_prices_ordered.values))Z_real = embed.fit_transform(minmax_scale(df_real.values))import matplotlib.pyplot as pltfig, ax = plt.subplots(1, 2, figsize=(16, 6))# Optiver Bucketssc0 = ax[0].scatter(Z_denorm[:, 0], Z_denorm[:, 1], c=np.arange(len(Z_denorm)), cmap='viridis', s=8)cb0 = fig.colorbar(sc0, ax=ax[0], shrink=0.8, pad=0.03)cb0.set_label("Spectral Order Index", fontsize=11, labelpad=10)_ = ax[0].set_title("Optiver Buckets\n(Colour = Spectral Order)", fontsize=13)_ = ax[0].set_xlabel("Spectral Dimension 1", fontsize=11)_ = ax[0].set_ylabel("Spectral Dimension 2", fontsize=11)# S&P-100 Dailysc1 = ax[1].scatter(Z_real[:, 0], Z_real[:, 1], c=mdates.date2num(df_real.index), cmap='viridis', s=8)cb1 = fig.colorbar(sc1, ax=ax[1], shrink=0.8, pad=0.03)cb1.set_label("TimeID order (Calender Data Progression)", fontsize=11, labelpad=10)_ = ax[1].set_title("S&P-100 Daily\n(Colour = Calendar Date)", fontsize=13)_ = ax[1].set_xlabel("Spectral Dimension 1", fontsize=11)_ = ax[1].set_ylabel("Spectral Dimension 2", fontsize=11)# Figure title and spacing_ = fig.suptitle("Spectral Embedding of Optiver vs S&P-100 Stock Closes", fontsize=12, y=1.02)fig.subplots_adjust(wspace=0.25, top=0.85)plt.tight_layout(rect=[0, 0, 1, 1])plt.show()
Figure 1: Spectral Embedding of Optiver Price Buckets vs. S&P‑100 Daily Closes (2020‑2021)
Finally, we turn prices into stationary log-returns:
log-diff RV = log Pₜ − log Pₜ₋₁. After this transform the Zivot–Andrews test fails to reject stationarity for every stock, giving us a stable target for modelling. The resulting panel (T − 1 = 3 829 × N = 120) underpins all subsequent feature engineering.
Show Code
#| echo: true#| warning: false#| message: false#| output: truedef compute_rv(grp):"""Compute realized volatility from intraday price data"""ifall(col in grp.columns for col in ['bid_price1', 'ask_price1', 'bid_size1', 'ask_size1']): wap = (grp['bid_price1'] * grp['ask_size1'] + grp['ask_price1'] * grp['bid_size1']) /\ (grp['bid_size1'] + grp['ask_size1'])else: wap = grp['ask_price1'] log_returns = np.log(wap).diff().dropna() rv = np.sqrt((log_returns **2).sum())return rv# RV for each stock and time_id combinationrv_records = []for _, row in df_files.iterrows():try: dfb = pd.read_csv(row.path) rv_series = (dfb.groupby('time_id') .apply(compute_rv, include_groups=False) .rename('rv') .reset_index()) rv_series['stock_id'] = row.stock_id rv_records.append(rv_series)exceptExceptionas e:print(f"Error processing {row.path}: {e}")continuerv_df = pd.concat(rv_records, ignore_index=True)# map to bucket_idx using time orderingtime_map = pd.DataFrame({'time_id': time_id_ordered})time_map['bucket_idx'] =range(len(time_map))rv_df = rv_df.merge(time_map, on='time_id')rv_pivot = rv_df.pivot(index='bucket_idx', columns='stock_id', values='rv')avg_rv = rv_pivot.mean(axis=1)
Show Code
# Stationarity testsfrom statsmodels.tsa.stattools import zivot_andrewsdef test_stationarity(ts, name="Series"):"""Test stationarity using Zivot-Andrews test (better for structural breaks)""" ts_clean = ts.dropna()try: za_stat, za_pval, za_cv, za_lag, za_bpidx = zivot_andrews( ts_clean.values, regression='ct', # break in both intercept and trend trim=0.15, maxlag=12, autolag='AIC' ) is_stationary = za_stat < za_cv['5%'] status ='✓ Stationary'if is_stationary else'✗ Non-stationary'print(f"{name:15} | ZA: {za_stat:6.3f} (p={za_pval:.3f}) | {status}")print(f"{'':15} | Break at index: {za_bpidx} | 5% crit: {za_cv['5%']:6.3f}")return is_stationaryexceptExceptionas e:print(f"{name:15} | ZA test failed: {str(e)}")returnFalsetest_stationarity(avg_rv, "Average RV by Time")
Average RV by Time | ZA: -10.702 (p=0.001) | ✓ Stationary
| Break at index: 3024 | 5% crit: -5.073
np.True_
Show Code
from statsmodels.tsa.stattools import zivot_andrewslog_rv = np.log(avg_rv +1e-8)stationary_rv = log_rv.diff().dropna()
Now that this transformation has worked as seen on the transformed volatility by time id plot we apply this transformation for each individual stock across its time_id
This section is further explained in detail in Appendix 9.2
3.2 Feature engineering
We need to enrich the transformed_rv_pivot () where T is the number of time buckets and N is the number of stocks with more detail to increase the information gain of the data for the GAT model to learn both temporal pattern and short term fluctuations. In order to achieve this we proceeded with the following features (See appendix 9.2 for more detail on features):
Own lags (3) – Rvᵢ,ₜ₋₁ … Rvᵢ,ₜ₋₃ capture short-range persistence.
In this section we describe how we turned the above normalized price matrix into a static stock stock-neighbour graph, assemble node-level features and the apply a two layer Graph Attention Network for bucket‐by‐bucket volatility forecasting.
3.3.1 Building the neighbour graph and GAT
What & why. For every stock we embed the most-recent 50-bucket price signature, this was done because stock relationships change over time! Stocks that moved together 2 years ago might not move together now (A logical assumption we made). and use a KD-tree to find its K=3 nearest neighbors. This captures current co-movement, recognizing that relationships drift over time.
Returns. A PyG-ready edge_index (source–destination pairs) and exponentially decaying edge_weight, plus the raw neighbor matrix and the list of stocks that survived NaN screening and with this, we captured 88% of the Optiver universe within the data
Show Code
from sklearn.preprocessing import MinMaxScalerfrom sklearn.neighbors import KDTreeimport torchimport numpy as npdef build_graph_on_features(X_features, time_window=50, K=3): T, N, F = X_features.shape# Use recent time window for similarity (stocks change over time)if time_window < T: recent_features = X_features[-time_window:, :, :] # Last 50 time stepselse: recent_features = X_features X_for_graph = recent_features.transpose(1, 0, 2).reshape(N, -1)# Remove stocks with missing features valid_stocks =~np.isnan(X_for_graph).any(axis=1) X_clean = X_for_graph[valid_stocks] valid_indices = np.where(valid_stocks)[0]print(f" Valid stocks: {len(valid_indices)} / {len(valid_stocks)}")# Min-Max scale feature space and build tree X_scaled = MinMaxScaler().fit_transform(X_clean) tree = KDTree(X_scaled, metric='euclidean') dist, nbr_raw = tree.query(X_scaled, k=K+1) # includes self# mapping back to original indices nbr = valid_indices[nbr_raw] src = np.repeat(valid_indices, K) dst = nbr[:, 1:].ravel() edge_index = torch.tensor([src, dst], dtype=torch.long) edge_weight = torch.exp(-torch.tensor(dist[:, 1:].ravel(), dtype=torch.float))return edge_index, edge_weight, nbr, valid_indicesedge_index, edge_weight, neighbor_indices, valid_indices = build_graph_on_features( X_initial_features, time_window=50, K=3)
Valid stocks: 112 / 112
Show Code
import numpy as npimport matplotlib.pyplot as pltimport networkx as nximport torchdef plot_kdtree_network(edge_index, edge_weight, valid_indices):ifisinstance(edge_index, torch.Tensor): edge_index = edge_index.cpu().numpy()ifisinstance(edge_weight, torch.Tensor): edge_weight = edge_weight.cpu().numpy() G = nx.Graph()for i, stock_idx inenumerate(valid_indices): G.add_node(i, stock_id=stock_idx) sources = edge_index[0] targets = edge_index[1]for i inrange(len(sources)): src = sources[i] tgt = targets[i] weight = edge_weight[i] G.add_edge(src, tgt, weight=weight) plt.figure(figsize=(10, 8)) pos = nx.spring_layout(G, k=3, iterations=50, seed=42) nx.draw_networkx_nodes(G, pos, node_color='lightblue', node_size=1000, linewidths=2, edgecolors='darkblue') edges = G.edges() weights = [G[u][v]['weight'] for u, v in edges]# Normalize weights for edge thickness max_weight =max(weights) if weights else1 edge_widths = [3* (w / max_weight) for w in weights] nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.7, edge_color='gray') labels = {i: f'{valid_indices[i]}'for i inrange(len(valid_indices))} nx.draw_networkx_labels(G, pos, labels, font_size=10, font_weight='bold') plt.title('Stock Relationship Network (KD-Tree)', fontsize=14, fontweight='bold') plt.axis('off') plt.figtext(0.02, 0.02, f'Nodes: {len(valid_indices)}\nEdges: {len(edges)}\nAvg Weight: {np.mean(weights):.3f}', fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray")) plt.tight_layout() plt.show()return G
What & why. Using the neighbor matrix below, we compute for every bucket t the average log-diff RV of each stock’s three neighbours and stack it with the six temporal channels built in § 4.1. This single cross-sectional feature gives the GAT some market “consensus” signal without inflating the feature dimension. We add the six temporal channels and the one neighbour channel to X_transformed converting it to a tensor for later training.
Show Code
def build_neighbour_feats(stationary_rv, neighbor_indices):""" Build features properly with correct dimensions """ T, N = stationary_rv.shape# Neighbor mean RV (1 feature) nei_mean = np.zeros((T, N, 1))for i inrange(N):if i <len(neighbor_indices) andlen(neighbor_indices[i]) >1: nei_idx = neighbor_indices[i, 1:] # exclude self nei_idx = nei_idx[nei_idx < N] # ensure valid indicesiflen(nei_idx) >0: nei_mean[:, i, 0] = stationary_rv[:, nei_idx].mean(axis=1)# Clean up X = np.nan_to_num(nei_mean, nan=0.0, posinf=0.0, neginf=0.0)print(f"Features created: {X.shape} (T × N × {X.shape[2]} features)")return X, stationary_rvX_transformed, y_transformed = build_neighbour_feats(transformed_rv_pivot, neighbor_indices)
Features created: (3829, 112, 1) (T × N × 1 features)
Then to learn the dynamic weightings of the neighbours we defined a simple 2-layer GAT that lets each stock fuse its own history with weighted neighbour signals.
Show Code
import torchimport torch.nn.functional as Ffrom torch_geometric.nn import GATConvclass ImprovedVolatilityGAT(torch.nn.Module):def__init__(self, in_feats, hidden=64, heads=4, dropout=0.3):super().__init__()self.dropout = dropout# Smaller, more stable architectureself.conv1 = GATConv(in_feats, hidden, heads=heads, dropout=dropout, concat=True)self.conv2 = GATConv(hidden * heads, hidden //2, heads=2, dropout=dropout, concat=True)self.conv3 = GATConv(hidden, 1, heads=1, concat=False, dropout=dropout)# Batch normalization for stabilityself.bn1 = torch.nn.BatchNorm1d(hidden * heads)self.bn2 = torch.nn.BatchNorm1d(hidden)self._initialize_weights()def _initialize_weights(self):"""Better weight initialization"""for m inself.modules():ifisinstance(m, torch.nn.Linear): torch.nn.init.xavier_uniform_(m.weight)if m.bias isnotNone: torch.nn.init.constant_(m.bias, 0)def forward(self, x, edge_index):# Layer 1 h =self.conv1(x, edge_index) h =self.bn1(h) if h.size(0) >1else h # Skip BN for single samples h = F.elu(h) h = F.dropout(h, p=self.dropout, training=self.training)# Layer 2 h =self.conv2(h, edge_index) h =self.bn2(h) if h.size(0) >1else h h = F.elu(h) h = F.dropout(h, p=self.dropout, training=self.training)# Output layer h =self.conv3(h, edge_index)return h.squeeze(-1)
3.3.2 Training, Loss and Cross Validation
We decided to use a hybrid approach for an optimum loss function. We first scale the means squared error by the target series own variance which will give a unit free % of variance magnitude penalty, then mix with a directional penalty that rises when the model get the sign of the volatility change wrong. We used a default weight of \(\alpha = 0.8\) so the network is rewarded for predicting both how big and which way volatility moves—ideal for trading settings.
Why 0.8: (Yin, 2023) introduce a *direction-integrated MSE* for stock forecasting and test λ ∈ {0.6–0.9}; their best models cluster at λ≈0.8. This was found from our literature review which we used as guidance here.
Training hyper-parameters were chosen by a small expanding-window grid search (Appendix 9.5) and fall squarely within values recommended by recent GAT studies: Adam with \(\text{lr} = 1 \times 10^{-3}\) and \(\text{weight-decay} = 1 \times 10^{-4}\) gave the lowest mean CV loss; \(\text{dropout} = 0.20\) and gradient-clipping at \(|g|_2 \leq 1.0\) eliminated over-fitting and exploding gradients; a ReduceLROnPlateau scheduler (factor \(0.7\), patience \(5\)) and early-stopping (patience \(15\), \(\delta = 1 \times 10^{-6}\)) cut training time by \(\sim 30%\) without degrading validation loss. Together these settings provide numerically stable, reproducible training while matching the error profiles demanded by our hybrid loss.
Show Code
#| echo: false#| warning: false#| message: false#| output: falsedevice = torch.device('cpu')#device = torch.device('mps' if torch.mps.is_available() else 'cpu') # can change mps to cuda for non metal devicesX_initial_features = build_initial_features(transformed_rv_pivot)
Features created: (3829, 112, 1) (T × N × 1 features)
Show Code
# ADD THIS LINE:X_complete = np.concatenate([X_initial_features, X_neighbor_features], axis=2)# UPDATE YOUR TRAINING:results = train_gat_fixed(X_complete, y_transformed, edge_index, edge_weight, device)
Saved best model with validation loss: 17.679580
3.4 Baseline models
To benchmark the performance of the GAT model, diverse baseline models were used including both traditional models and machine learning approaches. Log transformation of volatility was used across all baseline models to stabilize variance and improve model interpretability. These baseline models were served to evaluate whether GAT could capture patterns better than other models.
Evaluate linear relationship using lags, moving average and trend
Best Tuning: 5 windows
Show Code
import osfrom pathlib import Path# Use current directory as save pathsave_path = Path('.')# All model panelsmodels_panels = {'LAG': lag_panel,'HAR_RV': har_panel, 'Linear': linear_panel,'PCA_Linear': pca_linear_panel,'Random_Forest': rf_panel,'Gradient_Boosting': gb_panel}# 1. Save all model panels to CSVfor model_name, panel in models_panels.items(): filename =f"{model_name}_predictions.csv" filepath = save_path / filename panel.to_csv(filepath)
3.5 Evaluation protocol & metrics
The evaluation metrics that are used for both the GAT model and baseline model to evaluate the performance of the GAT model in capturing complex inter-stock relationships.
3.5.1 Root Mean Squared Percentage Error (RMSPE)
Measure the average squared difference between prediction and actual values in percentage format
It is scale-independent means it would provide comparable results across different size of datasets \[
\text{RMSPE} = \sqrt{ \frac{1}{n} \sum_{t=1}^n \left( \frac{y_t - \hat{y}_t}{y_t} \right)^2 }
\]
3.5.2 Quantile Likelihood (QLIKE)
Assess the quality of volatility scale estimation by penalizing misestimation of variance
It is sensitive to risk \[
\text{QLIKE} = \frac{1}{n} \sum_{t=1}^n \left( \frac{y_t}{\hat{y}_t} - \log \left( \frac{y_t}{\hat{y}_t} \right) - 1 \right)
\]
3.5.3 Mean Absolute Percentage Error (MAPE)
Measure the average difference between prediction and actual values in percentage format
It is easy to interpret and it is scale-independent \[
\text{MAPE} = \frac{1}{n} \sum_{t=1}^n \left| \frac{y_t - \hat{y}_t}{y_t} \right|
\]
3.5.4 Data - Splitting
For each individual stock, it would split the dataset into 80% training, 10% validation and 10% testing. The training set was used to fit the model. The validation set was used for hyper-parameter tuning while the test set was used to evaluate the final performance. This walk-forward split ensures that the future information would not leak into the past, maintaining data integrity.
import pandas as pdimport numpy as npimport torchfrom torch_geometric.data import Datadef diagnose_data_ranges():""" Diagnose the data ranges to identify the transformation issue """ gat_pred = pd.read_csv('GAT_predictions.csv', index_col=0) transformed_rv = create_stationary_features_fixed(rv_pivot)return gat_pred, transformed_rvdef regenerate_gat_predictions_correctly():""" Regenerate GAT predictions with CORRECT transformations """ total_time_steps =len(rv_pivot) test_start =int(total_time_steps *0.8) test_rv_pivot = rv_pivot.iloc[test_start:]#print(f"Test data shape: {test_rv_pivot.shape}") test_transformed_rv = create_stationary_features_fixed(test_rv_pivot) # log + diff X_test_temporal = build_initial_features(test_transformed_rv) # 6 features X_test_neighbor, _ = build_neighbour_feats(test_transformed_rv, neighbor_indices) # 1 feature X_test = np.concatenate([X_test_temporal, X_test_neighbor], axis=2) # 7 features#print(f"Test features shape: {X_test.shape}") checkpoint = torch.load('best_gat_model.pt', map_location=device) model = ImprovedVolatilityGAT( in_feats=7, hidden=32, heads=2, dropout=0.2 ).to(device) model.load_state_dict(checkpoint['model_state_dict']) model.eval() T, N, feat_dim = X_test.shape X_tensor = torch.tensor(X_test, dtype=torch.float) log_diff_predictions = []#print("Generating log-diff predictions...")with torch.no_grad():for t inrange(T): graph = Data( x=X_tensor[t], edge_index=edge_index, edge_weight=edge_weight ).to(device)# Get log-diff predictions (what the model was trained on) log_diff_pred = model(graph.x, graph.edge_index) log_diff_predictions.append(log_diff_pred.cpu().numpy()) log_diff_array = np.array(log_diff_predictions) adjusted_test_indices = test_rv_pivot.index[1:] # Account for diff operation log_diff_df = pd.DataFrame( log_diff_array, index=adjusted_test_indices, columns=test_rv_pivot.columns )#print(f"Log-diff predictions shape: {log_diff_df.shape}")#print(f"Log-diff range: [{log_diff_df.values.min():.6f}, {log_diff_df.values.max():.6f}]") log_diff_df.to_csv('GAT_log_diff_predictions.csv')return log_diff_df, test_rv_pivotdef calculate_metrics_correctly(log_diff_predictions, test_rv_pivot):""" Calculate metrics with CORRECT inverse transformations """ test_transformed_actual = create_stationary_features_fixed(test_rv_pivot) adjusted_indices = test_rv_pivot.index[1:] # Account for diff operation actual_log_diff_df = pd.DataFrame( test_transformed_actual, index=adjusted_indices, columns=test_rv_pivot.columns )#print(f"Actual log-diff shape: {actual_log_diff_df.shape}")#print(f"Actual log-diff range: [{actual_log_diff_df.values.min():.6f}, {actual_log_diff_df.values.max():.6f}]")# Align data common_indices = log_diff_predictions.index.intersection(actual_log_diff_df.index) common_columns = log_diff_predictions.columns.intersection(actual_log_diff_df.columns) pred_aligned = log_diff_predictions.loc[common_indices, common_columns] actual_aligned = actual_log_diff_df.loc[common_indices, common_columns]#print(f"Aligned data shape: {pred_aligned.shape}")#print(f"Common time points: {len(common_indices)}")#print(f"Common stocks: {len(common_columns)}") Y_pred_logdiff = pred_aligned.values Y_true_logdiff = actual_aligned.values mask =~(np.isnan(Y_true_logdiff) | np.isnan(Y_pred_logdiff)) Y_pred_clean = Y_pred_logdiff[mask] Y_true_clean = Y_true_logdiff[mask]#print(f"Valid data points: {len(Y_true_clean):,}")iflen(Y_true_clean) >0: rmse_logdiff = np.sqrt(np.mean((Y_pred_clean - Y_true_clean)**2)) mse_logdiff = np.mean((Y_pred_clean - Y_true_clean)**2) mae_logdiff = np.mean(np.abs(Y_pred_clean - Y_true_clean))# Method 2: Convert to RAW RV space for comparison with baselines# Convert log-diff predictions to raw RV# Step 1: exp() to get the multiplicative factors# Step 2: Need to reconstruct the cumulative log values Y_pred_raw = np.exp(Y_pred_clean) Y_true_raw = np.exp(Y_true_clean) eps =1e-12 Y_pred_raw = np.maximum(Y_pred_raw, eps) Y_true_raw = np.maximum(Y_true_raw, eps) rmse = np.sqrt(np.mean((Y_pred_raw - Y_true_raw)**2)) rmspe = np.sqrt(np.mean(((Y_pred_raw - Y_true_raw) / Y_true_raw)**2)) *100 ratio = Y_true_raw / Y_pred_raw qlike = np.mean(ratio - np.log(ratio) -1) mape = np.mean(np.abs(Y_pred_raw - Y_true_raw) / Y_true_raw) *100print(f"\nFINAL GAT PERFORMANCE:")print(f"RMSE: {rmse:.6f}")print(f"RMSPE: {rmspe:.2f}%")print(f"QLIKE: {qlike:.6f}")print(f"MAPE: {mape:.2f}%")return {'RMSE': rmse,'RMSPE': rmspe,'QLIKE': qlike,'MAPE': mape,'log_diff_rmse': rmse_logdiff if'rmse_logdiff'inlocals() elseNone }def main_correct_evaluation():""" Run the complete corrected evaluation """ diagnose_data_ranges() log_diff_pred, test_rv = regenerate_gat_predictions_correctly() metrics = calculate_metrics_correctly(log_diff_pred, test_rv)return metrics# Run the corrected evaluationcorrected_metrics = main_correct_evaluation()
Input matrix: 3830 time points × 112 stocks
Input matrix: 766 time points × 112 stocks
Initial features: (765, 112, 6) (T × N × 6 features)
Features created: (765, 112, 1) (T × N × 1 features)
Input matrix: 766 time points × 112 stocks
FINAL GAT PERFORMANCE:
RMSE: 0.626894
RMSPE: 34.35%
QLIKE: 0.044541
MAPE: 22.52%
Show Code
import numpy as npimport pandas as pdimport matplotlib.pyplot as pltfrom pathlib import Pathfrom matplotlib import gridspecmodel_names = ['HAR_RV', 'LAG', 'Linear', 'PCA_Linear','Random_Forest', 'Gradient_Boosting', 'GAT']# colour palette model_colours = {'HAR_RV' : "#0707E7",'LAG' : "#FB7D07",'Linear' : "#15F115",'PCA_Linear' : "#F40808",'Random_Forest' : "#8518EA",'Gradient_Boosting': "#8C564B",'GAT' : "#ED09A8"}actual_colour ="#030202"def load_and_align_data(): model_preds = {} shared_stocks =None shared_dates =None# Load all modelsfor model in model_names:if model =='GAT': preds = pd.read_csv('GAT_prediction_panel.csv', index_col=0)else: preds = pd.read_csv(f'{model}_predictions.csv', index_col=0) model_preds[model] = predsif shared_stocks isNone: actual = pd.read_csv('real_volatility.csv') shared_stocks =list(set(preds.columns).intersection(set(actual.columns))) shared_dates = preds.index.tolist()else: shared_stocks =list(set(shared_stocks).intersection(set(preds.columns))) shared_dates =list(set(shared_dates).intersection(set(preds.index.tolist()))) aligned_actual = actual.reindex(index=shared_dates) aligned_actual1 = np.exp(aligned_actual)return model_preds, aligned_actual1, shared_stocks, shared_datesdef create_comparison_plots(): preds, actual, stocks, dates = load_and_align_data() avg_actual = actual.loc[dates, stocks].mean(axis=1).values time_index = np.arange(len(avg_actual)) fig = plt.figure(figsize=(14, 10)) gs = gridspec.GridSpec(3, 3, height_ratios=[1.4, 1, 1], hspace=.45, wspace=.30)# function to plot def plot_one(ax, name): avg_pred = preds[name].loc[dates, stocks].mean(axis=1).values ax.plot(time_index, avg_actual, lw=1.6, label='Actual', color=actual_colour) ax.plot(time_index, avg_pred, lw=1.2, label=name, color=model_colours[name]) ax.set_xlabel('Time Index') ax.set_ylabel('Average volatility') ax.grid(True, ls='--', alpha=.6)# GAT ax_gat = fig.add_subplot(gs[0, :]) plot_one(ax_gat, 'GAT') ax_gat.set_title('GAT', fontweight='bold', fontsize=12)# Six Baseline baseline_axes = [fig.add_subplot(gs[r, c]) for r in (1, 2) for c inrange(3)] baseline_model = [m for m in model_names if m !='GAT']for ax, model inzip(baseline_axes, baseline_model): plot_one(ax, model) ax.set_title(model, fontsize=10)# Create custom Legend Legend = [] # This is for actual value actual_legend = plt.Line2D( [0], [0], color=actual_colour, linewidth=1.6, label='Actual' ) Legend.append(actual_legend)# This is for all modelsfor model in model_names: model_legend = plt.Line2D( [0], [0], color=model_colours[model], linewidth=1.2, label=model ) Legend.append(model_legend)# Put it in the figure legend fig.legend( handles=Legend, loc='lower center', ncol=4, frameon=False, # no border box fontsize=9, bbox_to_anchor=(0.5, -0.02) ) fig.suptitle('Model Predictions vs Actual Value', fontsize=14) plt.tight_layout(rect=[0, 0, 1, 0.94]) plt.show()# run itcreate_comparison_plots()
Figure 2: Actual-vs-Predicted Plots for All Models
We evaluate the Graph Attention Network (GAT) model on the test set using standard volatility forecasting metrics: Root Mean Squared Percentage Error (RMSPE), Mean Absolute Percentage Error (MAPE), and QLIKE loss, with RMSPE as the primary metric and lower values indicating better accuracy.
In the cross-sectional analysis of the last 10% of data, GAT demonstrates outstanding performance with an RMSE of 0.6, RMSPE of 0.344, QLIKE of 0.04, and MAPE of 0.224, as noted in the results table. Traditional models, such as HAR_RV and PCA_Linear, yield comparable results (RMSPE ~0.63+, QLIKE ~0.19), overall among the baseline models we saw PCA_linear and RF performing best showing that nonlinear models have advantages.
The findings underscore GAT’s effectiveness in modeling complex volatility patterns, while traditional models maintain competitive performance.
4.2 Evaluation Strategies Used
The GAT model followed 2 protocols for evaluation, during training it is assessed with a walk-forward, expanding-window cross-validation: four chronological folds created with TimeSeriesSplit. Within every fold we we we monitor with a hybrid loss function 80 % relative-MSE (MSE divided by target variance, expressed in %) and 20 % directional penalty (1 − sign-accuracy)—so the network is rewarded both for sizing volatility correctly and for getting the direction of change right.
We complete the training cycle with Adam (lr = 1 e-3, wd = 1 e-4), gradient clipping (‖g‖₂ ≤ 1), dropout = 0.2, batch-norm, ReduceLROnPlateau, and early stopping (δ = 1 e-6, patience = 15). Once we load the best state with lowest validation loss in the fold we evaluate on the unseen final 10% test window. reporting RMSPE, QLIKE, MAPE and RMSE—exactly the metrics required by volatility-forecasting literature.
Each baseline model is tuned separately but judged under the same time series split level playing field. Hyper-parameters are selected via grid search on the middle 10 % validation slice and the same metrics are computed on the held-out test slice. For models whose outputs are in log-space, predictions are exponentiated before scoring to keep scale consistent.
4.3 Interpretability
Linear models (HAR-RV, OLS): coefficients map cleanly to lagged RV, so drivers are transparent.
PCA + LR: noise-reduction helps accuracy, but latent components blur the link to raw features, hurting explainability.
Tree ensembles (RF, GB): capture non-linear effects yet produce forests of splits—useful variable rankings but little story. HAR-RV therefore remains the best accuracy/clarity compromise among baselines.
GAT goes further: every edge gets an attention weight αᵢⱼ, so we can read how much stock j moves stock i. Averaging α across the test set yields an “influence matrix” (Appendix 9.4) that lights up inside each GICS sector while down-weighting weak links. In practice the model says, e.g., tech stocks react mainly to other tech stocks; banks to banks.
Thus, despite its non-linear nature, the GAT provides sector-aware, quantitatively traceable explanations - something neither pure linear nor ensemble models can match.
5. Project Deployment & Interdisciplinary Impact
Deployment process:
The Shiny app is deployed on Posit Cloud with GitHub integration for continuous deployment. By linking the Posit Cloud project to a GitHub repository, it monitors the main branch for changes. When developers push commits, Posit Cloud triggers an automatic deployment pipeline, ensuring the live app reflects the latest stable version. This setup manages deployment complexities, allowing teams to focus on development while streamlining the workflow from commit to live application.
VoltaTrade is a Python Shiny web app for predicting and visualizing stock market volatility, featuring a modern dark-themed interface with glass-morphic design and smooth animations.
The application consists of five main modules (as shown in the figure below):
Overview Dashboard: A central hub displaying an overview of all features with interactive cards for easy navigation. This dashboard provides investors and traders with a quick snapshot of market conditions and app capabilities, helping them make informed decisions at a glance.
Model Details: An interactive laboratory comparing volatility prediction models, led by the GAT. It visualizes predicted versus actual value, feature importance, network influences between stocks, and real-time performance metrics. This module empowers users to understand and trust the predictive models, enabling them to select the most reliable tools for their trading strategies.
Stock Screener: Filters and ranks stocks based on financial metrics and volatility indicators to identify top performers.Investors and traders can efficiently discover promising stocks that match their risk and return preferences, streamlining the investment selection process.
Individual Stock Analysis: Provides in-depth analysis of single stocks with detailed volatility metrics, historical data visualization, and predictive insights.This tab allows users to conduct thorough due diligence on specific stocks, supporting more confident and data-driven investment decisions.
Stock Comparison: Enables side-by-side comparison of multiple stocks with visual indicators for volatility, returns, and key financial metrics. Traders and investors can easily compare potential investments, facilitating portfolio diversification and optimal asset allocation.
Portfolio Tracker: Tracks portfolio performance with advanced volatility analytics and risk assessment tools. This feature helps users monitor their holdings and manage risk, ensuring their portfolios align with their financial goals and market outlook.
Image 2: Shiny App Tabs
The app leverages a self-trained GAT predicting model (final model). It includes AI features through Open-AI integration, boasts a sophisticated UI with custom CSS animations, responsive design, and a modern sidebar navigation. The app processes volatility data and offers real-time analysis for financial decision-making.
6. Discussion
6.1 Key findings
Superior GAT Performance: The Graph Attention Network (GAT) excelled, achieving the lowest errors (QLIKE: 0.092, MAPE: 0.46, RMSPE: 0.57), reflecting precise point-wise predictions and enhanced volatility calibration compared to traditional models.
Inter-Stock Modeling: GAT’s attention-based graph structure captures temporal and cross-stock dependencies, dynamically weighting neighboring stocks’ influence, which improves generalization and predictive accuracy in complex market settings.
Interpretability Gains: The attention mechanism highlights influential stocks, such as sector-based peers (e.g., tech stocks impacting other tech stocks), offering transparency in financial forecasting.
Baseline Comparison: HAR-RV performed well during volatility shifts (RMSPE ~0.60) due to its hierarchical lag structure, but lacks cross-asset modeling, limiting its scope. PCA-Linear captured latent structures (RMSPE ~0.60) but reduced interpretability. Tree-based models (Random Forest, Gradient Boosting) under-performed (RMSPE >1.50), failing to adapt to regime shifts.
6.2 Limitations
Static Graph Structure: GAT’s graph, built on historical price similarity, remains fixed, unable to adapt to real-time market shifts (e.g., during crises), a constraint in dynamic financial environments.
Graph Construction Sensitivity: Performance hinges on neighbor selection (K=3), distance metrics, and price-based similarity, which may miss latent correlations, leading to degraded accuracy and instability across volatility regimes.
Limited Feature Space: Excluding macroeconomic indicators (e.g., interest rates) and news sentiment restricts generalizability, as volatility often reflects external shocks not captured by historical RV data.
Tick Structure Assumption: Reconstructing prices assumes a 0.01 tick increment, unsuitable for assets with varying tick sizes or liquidity, introducing bias in network science and financial modeling.
Computational Cost: GAT’s multi-head attention increases training demands, limiting real-time use in low-latency trading, a challenge in machine learning deployment.
Uncertainty Quantification: Lacking confidence intervals, GAT struggles in risk-sensitive financial contexts, reducing its practical utility for decision-making.
6.3 Future work
Dynamic Graphs: Implement adaptive graph construction to reflect evolving market conditions, improving inter-stock modeling in finance and network science.
Broader Features: Integrate macroeconomic indicators and news sentiment, enhancing data science robustness for external shock scenarios in financial markets.
Model Optimization: Apply pruning to reduce GAT’s computational load, enabling real-time applications, addressing machine learning efficiency.
Uncertainty Measures: Add predictive uncertainty to support risk-sensitive decisions, bridging finance and network science for practical deployment.
7. Conclusion
This study addresses whether neural networks can leverage inter-stock relationships from market metrics to enhance realized volatility forecasting accuracy compared to traditional models. Our Graph Attention Network (GAT) model confirms this, achieving superior performance (RMSPE=0.57) over models like HAR-RV. Deployed through the VoltaTrade Shiny application, GAT’s attention mechanism elucidates sector-specific influences, enhancing interpretability for financial decision-making. Integrating finance, machine learning, network science, and data science, this work offers a robust, accessible solution. Future enhancements include dynamic graph structures to capture evolving market dynamics, incorporation of macroeconomic indicators, model optimization for real-time use, and uncertainty quantification to support risk-sensitive applications, ensuring scalability and broader applicability.
8. Student Contribution
Shreya Prakash (520496062)
Shreya coordinated the integration of the team’s work, designing a user-friendly Shiny App prototype to showcase project results. She also drafted the project report and prepared impactful presentation slides to communicate findings clearly.
Chenuka Garunsinghe (530080640)
Chenuka recovered time IDs and ensured data stationarity to support accurate analysis. He developed and optimized the Graph Attention Network (GAT) model, enhancing the project’s predictive capabilities.
Enoch Wong (530531430)
Enoch created Figure 1 to visualize key data insights and trained baseline models to establish performance benchmarks. He contributed to the project report and presentation slides, ensuring clear communication of results.
Binh Minh Tran (530414672)
Binh developed the Shiny App, enabling interactive visualization of project outcomes. He also collaborated on the GAT model, improving its design and integration with project goals.
Zoha Kausar (530526838)
Zoha crafted visually engaging presentation slides to effectively convey the team’s findings. She also drafted and refined the project report, ensuring clarity and alignment with objectives.
Ruohai Tao (540222281)
Ruohai trained baseline models and recovered time IDs, ensuring data consistency for analysis. He created visualizations and contributed to the report and results slides, highlighting key project outcomes.
You can render the document or run code chunks directly step by step manually by following the UI in your environment (normally run option is above each code chunk for manual execution). It is suggested to have the yaml code jupyter: python3 in the yaml section in the top of the code file in-order to execute in the Jupyter terminal which is more perfomant and flexible. Make sure to install the packages listed in the file requirements.txt via the command pip install -r requirements.txt in your terminal. It is suggested to use a virtual python environment as well. You can change the machine based on your device - either CPU or MPS.
9.2 Denormalization and Cleaning (3.2.2) in detail
From our literature review we found that the Kaggle discussion threads revealed that the prices were scaled by an unknown divisor D and then rounded to the nearest real market tick size (~ $0.01). For every (stock_id, time_id) we, forward fill the 600 snapshots so that every second has a quote. Compute first the differences in the price \(\delta P = price_t - price_{t-1}\) and find the smallest non zero absolute jump; that equals \(\frac{1 \text{tick}}{D}\) then multiply the whole bucket by \(\frac{0.01}{\text{min}(|\delta P_{norm}|)}\). We get the real prices by doing \(P^{\text{real}}_t = D \times P^{\text{norm}}_t\). See below Appendix for more detail.
The resulting 3830 x 120 matrix is our master price panel. A quick histogram of \(\delta P\) by tick show exactly integers only which confirms to us the re-scaling recovered genuine tick units.
9.2.1 Handling the gaps and extreme quotes
Similar to earlier we used forward / backward to impute the remaining holes with the last known quotes; this preserves the micros structure dynamics without fabricating new trends and loosing generality of our method. We exclude a stock if more than 0.05 % of its 1-second snapshots are missing on any trading day (≈ 44 of 88 200). This ceiling keeps the expected gap below 1 s in a 10-minute bucket, ensuring forward-fill imputation cannot materially flatten high-frequency dynamics. To prevent single tick glitches we from exploding volatility estimates we Winsorize each stocks price at the 0.1 % and 99.9% of the quantiles.
This now underpins all subsequent features. To finally recover the chronological order of the time_ids to improve the per bucket RV prediction we embedded each bucket in a 1-D spectral manifold and sort by the leading eigen-coordinate. Because prices evolve almost monotonically intra-day, the leading spectral component monotonises the shuffled ids, effectively restoring the hidden chronology. We validate the approach by applying the same embedding to daily closing prices of the S&P-100 (right panel in Figure 1); the recovered order aligns perfectly with calendar dates, confirming the method’s fidelity.
9.2.2 Assessing Characteristics and trends
After recovering the chronology of the time_ids, for each stock we calculated their RV across the time_id trend and then calculated the average RV per time_id from all the stocks, plotting a Averaged RV against time_id formally $ t ;=;1N^N _{t,i}$. Below is the trend of the data we observed.
Show Code
#| echo: true#| warning: false#| message: false#| output: true#| label: fig-performance#| fig-cap: "Realized Volatility Before and After Transformation"#| fig-align: centerplot_volatility_transformation(avg_rv, stationary_rv, "Log + First Diff")
One challenge we had with the data is testing for stationary when there may be structural changes (breaks). The standard Audgmented Dickey-Fuller (ADF) test assumes the data-generation process is constant over time and often missclasiffies a series if there is a sudden shift as non-stationary. To address this we resorted to using the Zivot-Andrews test, which endogenously estimates and accounts for a single break in either the intercept or trend.
Looking at the above Average RV over time we can see that despite the Zivot-Andrews shows that the series is stationary around a broken trend it the model will see very different variance scales in chronological train vs test split that we will be discussing below. Because of this it was clearly logical to even out the regime shift with appropriate transformations on the data.
9.2.3 Transforming the data
The goal is to compress the high volatility spikes so the model does not treat them as totally out of sample. Hence, a good candidate we chose was log + first differences where \(\epsilon=10^{-8}\) guards against log(0). After transformation, the series oscillates around zero with roughly constant variance, making chronological splitting much more reliable (the model is no longer “blindsided” by a massive spike) while still being stationary.
<networkx.classes.graph.Graph object at 0x32e60e110>
Figure 3: Stock Relation Network (KD-Tree)
9.3 Feature Engineering (3.2) in detail
We need to enrich the transformed_rv_pivot (\(T \times N\)) where T is the number of time buckets and N is the number of stocks with more detail to increase the information gain of the data for the GAT model to learn both temporal pattern and short term fluctuations. In order to achieve this we proceeded with the following features (See appendix for more detail on features):
Own-RV lags: RV often exhibits auto-persistence: a high‐volatility bucket tends to be followed by elevated volatility. So for each stock we included the precious three buckets of transformed RV to capture the short-term persistence of volatility.
Volatility Momentum: Beyond raw persistence, we want to capture changes in the short-term trend—for example, if volatility is accelerating or decelerating, this was done by calculating difference between the average RV over most recent and preceding three buckets
Mean reversion tendency: empirically, volatility often reverts toward a longer‐term mean after extreme moves. We calculate the negative deviation of the current RV from its ten‐bucket rolling average, encoding how strongly each stock’s volatility is “pulled back” toward a longer‐term mean.
Volatility of volatility: Some stocks exhibit wild swings in volatility itself (for instance, jumps around earnings). We take the rolling standard deviation of the last five buckets of RV for each stock, quantifying how erratic or “jittery” the volatility itself has been over that short window
9.4 Correlation Heatmap
Show Code
import torchimport numpy as npimport matplotlib.pyplot as pltimport seaborn as snsfrom torch_geometric.data import Datadef extract_and_plot_attention(model, X_test, edge_index, edge_weight, valid_indices, sample_size=20):""" Extract attention weights and create heatmap Parameters: - model: trained GAT model - X_test: test features [T, N, F] - edge_index: edge connections [2, E] - edge_weight: edge weights [E] - valid_indices: valid stock indices - sample_size: number of time steps to sample """ model.eval() T, N, F = X_test.shape attention_matrix = np.zeros((N, N)) count_matrix = np.zeros((N, N))with torch.no_grad():for t inrange(min(sample_size, T)):# Create graph for this time step - move to same device as model graph = Data( x=torch.tensor(X_test[t], dtype=torch.float).to(next(model.parameters()).device), edge_index=edge_index.to(next(model.parameters()).device), edge_weight=edge_weight.to(next(model.parameters()).device) )try:# Forward pass with attention extraction# This modifies the first layer to return attention x = graph.x h, (edge_idx, attention_weights) = model.conv1(x, graph.edge_index, return_attention_weights=True)# Convert to numpy edge_idx = edge_idx.cpu().numpy() attention_weights = attention_weights.cpu().numpy()# Handle multi-head attention (average across heads)iflen(attention_weights.shape) >1: attention_weights = attention_weights.mean(axis=-1)# Fill attention matrixfor i, (src, tgt) inenumerate(edge_idx.T):if src < N and tgt < N: attention_matrix[src, tgt] += attention_weights[i] count_matrix[src, tgt] +=1exceptExceptionas e:print(f"Error at time step {t}: {e}")continue# Average attention weights mask = count_matrix >0 attention_matrix[mask] = attention_matrix[mask] / count_matrix[mask]# Create heatmap plt.figure(figsize=(12, 10))# Use subset for cleaner visualization if too many stocksif N >30: subset_size =30 subset_idx = np.random.choice(N, subset_size, replace=False) subset_matrix = attention_matrix[np.ix_(subset_idx, subset_idx)] subset_labels = [f'S{valid_indices[i]}'for i in subset_idx] title =f"GAT Attention Matrix (Random {subset_size} stocks)"else: subset_matrix = attention_matrix subset_labels = [f'S{valid_indices[i]}'for i inrange(N)] title ="GAT Attention Matrix (All stocks)"# Plot heatmap sns.heatmap(subset_matrix, xticklabels=subset_labels, yticklabels=subset_labels, cmap='Blues', square=True, cbar_kws={'label': 'Attention Weight'}, linewidths=0.1) plt.title(title, fontsize=14, fontweight='bold') plt.xlabel('Target Stock (receives attention)') plt.ylabel('Source Stock (gives attention)') plt.xticks(rotation=45) plt.yticks(rotation=0) plt.tight_layout() plt.show()# Print some statistics non_zero = np.count_nonzero(attention_matrix) max_attention = np.max(attention_matrix) mean_attention = np.mean(attention_matrix[attention_matrix >0])print(f"\nAttention Statistics:")print(f"Non-zero attention weights: {non_zero}")print(f"Max attention weight: {max_attention:.4f}")print(f"Mean attention weight: {mean_attention:.4f}")return attention_matrix
Show Code
import torchimport numpy as npimport matplotlib.pyplot as pltimport seaborn as snsfrom torch_geometric.data import Datadef load_and_visualize_attention(edge_index, edge_weight, valid_indices, neighbor_indices, rv_pivot, device):""" Parameters: - edge_index: your pre-computed edge connections - edge_weight: your pre-computed edge weights - valid_indices: your valid stock indices - neighbor_indices: your neighbor matrix - rv_pivot: your realized volatility pivot table - device: your torch device """ checkpoint = torch.load('best_gat_model.pt', map_location=device) model = ImprovedVolatilityGAT( in_feats=7, # 6 temporal + 1 neighbor feature hidden=32, heads=2, dropout=0.2 ).to(device)# Load weights model.load_state_dict(checkpoint['model_state_dict'])print(f"Model loaded with validation loss: {checkpoint['val_loss']:.6f}")print(f"Model device: {next(model.parameters()).device}")# Prepare test data (use the same approach as in your evaluation) total_time_steps =len(rv_pivot) test_start =int(total_time_steps *0.8) test_rv_pivot = rv_pivot.iloc[test_start:]# Create test features (same as your training pipeline) test_transformed_rv = create_stationary_features_fixed(test_rv_pivot) X_test_temporal = build_initial_features(test_transformed_rv) X_test_neighbor, _ = build_neighbour_feats(test_transformed_rv, neighbor_indices) X_test = np.concatenate([X_test_temporal, X_test_neighbor], axis=2)print(f"Test data shape: {X_test.shape}")# Make sure edge tensors are on the correct device edge_index = edge_index.to(device) edge_weight = edge_weight.to(device)print(f"Edge tensors moved to device: {edge_index.device}")# Extract and visualize attention attention_matrix = extract_and_plot_attention( model, X_test, edge_index, edge_weight, valid_indices, sample_size=20 )return model, attention_matrixmodel, attention_matrix = load_and_visualize_attention(edge_index, edge_weight, valid_indices, neighbor_indices, rv_pivot, device)
Model loaded with validation loss: 17.679580
Model device: cpu
Input matrix: 766 time points × 112 stocks
Initial features: (765, 112, 6) (T × N × 6 features)
Features created: (765, 112, 1) (T × N × 1 features)
Test data shape: (765, 112, 7)
Edge tensors moved to device: cpu
Attention Statistics:
Non-zero attention weights: 448
Max attention weight: 1.0000
Mean attention weight: 0.2500
Figure 4: Heatmap of inter-stock correlations
9.5 GAT Training Terminology and Details
Table for GAT terminology
References
Andersen, T.G. et al. (2003) “Modeling and forecasting realized volatility,”Econometrica, 71(2), pp. 579–625.
Cont, R. (2001) “Empirical properties of asset returns: Stylized facts and statistical issues,”Quantitative Finance, 1(2), pp. 223–236.
Corsi, F. (2009) “A simple approximate long-memory model of realized volatility,”Journal of Financial Econometrics, 7(2), pp. 174–196.
Velickovic, P. et al. (2018) “Graph attention networks,” in International conference on learning representations (ICLR).
Yin, H. (2023) “Enhancing directional accuracy in stock closing price value prediction using a direction-integrated MSE loss function,”International Journal of Computational Science and Engineering, 28(1), pp. 1–12. Available at: https://www.scitepress.org/Papers/2023/128102/128102.pdf.
Zhang, Y. et al. (2022) “Stock price movement prediction with graph attention networks,”Expert Systems with Applications, 191, p. 116367.